Okay, well, so let's get started.
Last time I introduced the JAX package to you as one of those frameworks that are really useful for doing machine learning without going too deep into the technical details or having to program everything yourself.
And we went through a few code examples.
So I now want to point out that I made sure to put a kind of JAX cheat sheet on the website.
So let me just switch to the website.
So this is the lecture website, and here you can find a very concise, well, not even an introduction but rather a collection of code snippets that you can readily use.
So these are partially taken from the examples that I showed you and some of the examples
that I will show you.
And whenever in the future you are going to define your neural networks and train them using JAX, you can basically just copy and paste from here, and I made sure to keep it as simple as possible.
So you will get a chance to go through all of this next Monday in the tutorial, but I nevertheless want to show you a few further elements before we go on in the lecture course to discuss image recognition and convolutional neural networks.
So this is, I believe, the notebook we had been looking at last time. We defined a neural network, we defined a cost function, and we saw how to take the gradient of the cost function with respect to the parameters.
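Just to make that recap concrete, here is a minimal sketch of that kind of setup; the architecture, shapes, and names are illustrative, not the exact code from the notebook:

```python
import jax
import jax.numpy as jnp

def network(params, x):
    # a single hidden layer with tanh activation and a linear output layer
    W1, b1, W2, b2 = params
    hidden = jnp.tanh(W1 @ x + b1)
    return W2 @ hidden + b2

def cost(params, x, y_target):
    # squared-error cost between the network output and the target
    y = network(params, x)
    return jnp.sum((y - y_target) ** 2)

# illustrative parameter shapes: 3 inputs, 10 hidden units, 1 output
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (jax.random.normal(k1, (10, 3)), jnp.zeros(10),
          jax.random.normal(k2, (1, 10)), jnp.zeros(1))

# gradient of the cost with respect to the parameters (the first argument)
grad_cost = jax.grad(cost, argnums=0)
print(grad_cost(params, jnp.ones(3), jnp.zeros(1)))
```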
And on the website you will find further links. Let me just try to bring them up in a better way. You will find the links to this notebook, and then there are more interesting things to learn about JAX.
So this again is on GitHub, and what you can see in this notebook is what we discussed before. But then JAX offers a few extra tricks.
So one of the ideas is that you can actually compile code just in time. This means that even though Python is an interpreted language, JAX can give you very performant code. It also means that if you don't want to run it on a CPU but on a graphics processing unit, JAX takes care of all of that automatically.
And the way to do this is super simple. All you have to do is place this little modifier, which is actually called a decorator, in front of the function, and then JAX knows how to compile it. JIT stands for just-in-time compilation. And what you see is simply that whatever code you decorate in this way will run significantly faster than it would have otherwise.
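As a minimal sketch of what this looks like, where the function itself is just a stand-in for whatever computation you want to speed up:

```python
from jax import jit
import jax.numpy as jnp

@jit  # the decorator: just-in-time compile this function
def my_update(x):
    # some array computation standing in for a real training step
    return x - 0.1 * jnp.tanh(x)

x = jnp.ones(1000)
y = my_update(x)  # the first call compiles; later calls reuse the compiled code
```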
So that's one way to just-in-time compile code. The other way is to apply the jit function that you import from JAX to your function, in the same way that you apply the grad function to a function, and what is returned is then a compiled function. So here, for example, I took the cost function, I told JAX that I want the gradient with respect to the first argument, that's what this says, and then I also told JAX that I want the compiled version of that whole thing. Whatever I get back I call grad_cost, and this is now a function, but whenever I call this function I'm actually calling the compiled version of it.
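In code, that chain looks roughly like this, reusing the cost function from the sketch above:

```python
from jax import grad, jit

# gradient with respect to the first argument (the parameters),
# wrapped in jit so that calling it runs the compiled version
grad_cost = jit(grad(cost, argnums=0))

gradients = grad_cost(params, x, y_target)
```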
So this is one of the little tricks that JAX offers us. The rest is the same as before. And there's one little extra piece of information: typically you don't just want the gradient of the cost function, you also want the value of the cost function, and there's a little convenience routine that gives you both of these things.
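That convenience routine is value_and_grad; as a minimal sketch, again reusing the cost function from above:

```python
from jax import value_and_grad, jit

# value_and_grad returns a function that gives both the cost value and
# its gradient in a single call, so the network is not evaluated twice
val_and_grad_cost = jit(value_and_grad(cost, argnums=0))

cost_value, gradients = val_and_grad_cost(params, x, y_target)
```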